Skip to content

fix: preserve MDP integrity in PPO mini-batching#98

Open
lqzxt wants to merge 1 commit into
AgentR1:mainfrom
lqzxt:fix-preserve-trajectory-mini-batches
Open

fix: preserve MDP integrity in PPO mini-batching#98
lqzxt wants to merge 1 commit into
AgentR1:mainfrom
lqzxt:fix-preserve-trajectory-mini-batches

Conversation

@lqzxt
Copy link
Copy Markdown
Collaborator

@lqzxt lqzxt commented Jun 4, 2026

Summary

Preserve MDP integrity during PPO updates by ensuring all reasoning steps from the same trajectory are assigned to the same mini-batch. This prevents a single trajectory from being split across different actor or critic update batches, which would break the consistency of trajectory-level MDP optimization.

Changes

  • Add trajectory-level mini-batch preparation to keep all steps from the same trajectory in one PPO mini-batch.
  • Attach mini-batch IDs, sample masks, and global batch/token metadata to planned update batches.
  • Update actor and critic workers to split by precomputed trajectory mini-batch IDs when available.
  • Pass global mini-batch statistics into policy, entropy, KL, and value loss aggregation.
  • Support explicit mini-batch counts in the training worker dispatch path.

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces trajectory-aware PPO mini-batching helpers to preserve whole trajectories across mini-batches. It updates the Ray trainer, DP actor, DP critic, and loss computation functions to support planned mini-batches and propagate global batch metadata (such as DP size, global batch size, and token counts) for proper loss aggregation. The review feedback highlights two critical improvements in trajectory_batching.py: first, handling empty valid indices defensively to prevent pipeline crashes, and second, vectorizing token sum calculations on the GPU to avoid performance bottlenecks caused by row-by-row GPU-CPU synchronizations.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +37 to +39
valid_indices = _valid_indices(data)
if not valid_indices:
raise ValueError("trajectory mini-batching requires at least one valid row")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If valid_indices is empty (e.g., when all samples in the batch are masked out or invalid), raising a ValueError will crash the entire training pipeline. It is safer to handle this case defensively by returning a dummy mini-batch of all padding, which avoids crashes while correctly marking all samples as invalid.

Suggested change
valid_indices = _valid_indices(data)
if not valid_indices:
raise ValueError("trajectory mini-batching requires at least one valid row")
valid_indices = _valid_indices(data)
if not valid_indices:
device = _batch_device(data.batch)
prepared = data.select_idxs(list(range(len(data))))
prepared.batch["mini_batch_id"] = torch.zeros(len(data), dtype=torch.long, device=device)
prepared.batch["sample_mask"] = torch.zeros(len(data), dtype=torch.bool, device=device)
prepared.batch["mini_batch_global_size"] = torch.zeros(len(data), dtype=torch.long, device=device)
prepared.batch["mini_batch_global_token_num"] = torch.zeros((len(data), 1), dtype=torch.long, device=device)
prepared.batch["mini_batch_global_response_token_num"] = torch.zeros(len(data), dtype=torch.long, device=device)
prepared.meta_info = dict(getattr(prepared, "meta_info", {}))
prepared.meta_info["num_mini_batch"] = 1
return prepared

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think returning a dummy mini-batch is safe here. valid_indices is derived from sample_mask, which in this flow only marks world-size padding rows. A non-empty training batch should always contain at least one real row; if all rows are invalid, that is an upstream invariant violation and failing fast is preferable.

The proposed dummy batch is not a no-op: planned mini-batches are still executed by actor/critic/engine workers, and the loss path does not use sample_mask as the skip condition. The suggestion also does not clear response_mask/loss_mask, so invalid rows may still affect loss/metrics depending on the path. If we ever need to support an all-invalid batch, the trainer should skip actor/critic updates explicitly, not fabricate mini-batch metadata here.

Comment on lines +175 to +209
def _assign_global_mini_batch_info(
prepared: Any,
source_batch: Any,
mini_batches: list[list[list[int]]],
mini_batch_ids: torch.Tensor,
device: torch.device,
) -> None:
row_counts = [sum(len(group) for group in mini_batch) for mini_batch in mini_batches]
max_rows = max(row_counts)
global_sizes = torch.tensor(row_counts, dtype=torch.long, device=device)
prepared.batch["mini_batch_global_size"] = global_sizes[mini_batch_ids]

token_num_table = torch.zeros((len(mini_batches), max_rows), dtype=torch.long, device=device)
response_token_nums = torch.zeros(len(mini_batches), dtype=torch.long, device=device)
attention_mask = source_batch.get("attention_mask", None)
response_mask = source_batch.get("response_mask", None)

for mini_batch_id, mini_batch in enumerate(mini_batches):
row_indices = [row_idx for group in mini_batch for row_idx in group]
if attention_mask is not None:
source_token_nums = attention_mask.new_tensor(
[_sum_row_tokens(attention_mask, row_idx) for row_idx in row_indices], dtype=torch.long
)
token_num_table[mini_batch_id, : len(row_indices)] = source_token_nums.to(device)
if response_mask is not None:
response_token_nums[mini_batch_id] = int(
sum(_sum_row_tokens(response_mask, row_idx) for row_idx in row_indices)
)

prepared.batch["mini_batch_global_token_num"] = token_num_table[mini_batch_ids]
prepared.batch["mini_batch_global_response_token_num"] = response_token_nums[mini_batch_ids]


def _sum_row_tokens(tensor: torch.Tensor, row_idx: int) -> int:
return int(tensor[row_idx].sum().detach().cpu().item())
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of _assign_global_mini_batch_info uses a loop over all rows in the mini-batch and calls _sum_row_tokens, which performs .detach().cpu().item() on each row. This causes multiple GPU-CPU synchronizations per step, creating a massive performance bottleneck. We can completely avoid this by precomputing the sum of tokens for all rows in a single vectorized operation on the GPU.

def _assign_global_mini_batch_info(
    prepared: Any,
    source_batch: Any,
    mini_batches: list[list[list[int]]],
    mini_batch_ids: torch.Tensor,
    device: torch.device,
) -> None:
    row_counts = [sum(len(group) for group in mini_batch) for mini_batch in mini_batches]
    max_rows = max(row_counts)
    global_sizes = torch.tensor(row_counts, dtype=torch.long, device=device)
    prepared.batch["mini_batch_global_size"] = global_sizes[mini_batch_ids]

    token_num_table = torch.zeros((len(mini_batches), max_rows), dtype=torch.long, device=device)
    response_token_nums = torch.zeros(len(mini_batches), dtype=torch.long, device=device)
    attention_mask = source_batch.get("attention_mask", None)
    response_mask = source_batch.get("response_mask", None)

    attention_token_counts = attention_mask.sum(dim=-1) if attention_mask is not None else None
    response_token_counts = response_mask.sum(dim=-1) if response_mask is not None else None

    for mini_batch_id, mini_batch in enumerate(mini_batches):
        row_indices = [row_idx for group in mini_batch for row_idx in group]
        if attention_token_counts is not None:
            source_token_nums = attention_token_counts[row_indices]
            token_num_table[mini_batch_id, : len(row_indices)] = source_token_nums
        if response_token_counts is not None:
            response_token_nums[mini_batch_id] = response_token_counts[row_indices].sum()

    prepared.batch["mini_batch_global_token_num"] = token_num_table[mini_batch_ids]
    prepared.batch["mini_batch_global_response_token_num"] = response_token_nums[mini_batch_ids]

@lqzxt lqzxt requested a review from 0russwest0 June 6, 2026 18:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant